class Solution:
    def punishmentNumber(self, n: int) -> int:
        """Return the punishment number of ``n``.

        This is the sum of ``i * i`` over every ``1 <= i <= n`` whose square's
        decimal representation can be cut into contiguous substrings whose
        integer values add up to ``i``.
        """

        def splits_to(value, target):
            # Keeping the whole remaining number as a single piece is a cut.
            if value == target:
                return True
            digits = str(value)
            # Try each proper prefix as the next piece; the rest of the digits
            # must then be splittable into whatever amount is still needed.
            # int() drops leading zeros of the suffix, which cannot change
            # the set of sums the suffix can produce.
            for cut in range(1, len(digits)):
                head = int(digits[:cut])
                if head > target:
                    continue
                if splits_to(int(digits[cut:]), target - head):
                    return True
            return False

        return sum(i * i for i in range(1, n + 1) if splits_to(i * i, i))
    
def f(tot: int, i: int) -> bool:
    """Report whether the decimal digits of ``tot`` can be cut into contiguous
    pieces whose integer values sum to exactly ``i``.

    The whole number taken as one piece counts. Otherwise every proper prefix
    that does not exceed ``i`` is tried as the first piece, and the remaining
    suffix is checked recursively against what is left of ``i``. Converting
    the suffix with ``int()`` drops its leading zeros, which does not change
    the set of sums the suffix can produce.
    """
    if tot == i:
        return True
    digits = str(tot)
    # any() short-circuits on the first cut that works, and the filter keeps
    # the "prefix must not exceed the remaining target" pruning.
    return any(
        f(int(digits[cut:]), i - int(digits[:cut]))
        for cut in range(1, len(digits))
        if int(digits[:cut]) <= i
    )
if __name__ == "__main__":
    # Manual sanity check: 36 * 36 = 1296 = 1 + 29 + 6, so this prints True.
    # Guarded so that importing this module (e.g. to use Solution) does not
    # produce output.
    print(f(36 * 36, 36))
